import coremltools as ct
import numpy as np
from coremltools.converters.mil import Builder as mb

# Deployment target; reused below as the MIL opset version and as the
# conversion's minimum deployment target.
target = ct.target.iOS15

# Static input shapes of the program.
#   x_shape:    assumed (N, C, H_in, W_in) for resample's `x` — confirm against coremltools docs.
#   grid_shape: assumed (N, H_out, W_out, 2) sampling coordinates — confirm likewise.
x_shape, grid_shape = (2, 2, 3, 2), (2, 3, 2, 2)


@mb.program(
    input_specs=[mb.TensorSpec(shape=x_shape), mb.TensorSpec(shape=grid_shape)],
    opset_version=target,
)
def prog(x, grid):
    """Single-op MIL program: ``resample`` of ``x`` at the locations in ``grid``.

    Configuration: bilinear sampling, reflection padding, coordinates in the
    ``normalized_minus_one_to_one`` convention, ``align_corners=False``.
    Each mode argument is materialized as a named ``const`` op (created in the
    order listed below) and then passed to ``mb.resample``.

    The parameter names ``x`` and ``grid`` are used as the program's input
    names, which is why the predictions elsewhere feed ``{"x": ..., "grid": ...}``.
    """
    # (resample keyword, const op name, const value), in const-creation order.
    const_specs = (
        ("sampling_mode", "sampling_mode", "bilinear"),
        ("padding_mode", "pmode", "reflection"),
        # NOTE(review): presumably only consulted for "constant" padding — confirm in the resample docs.
        ("padding_value", "pval", np.float32(0)),
        ("coordinates_mode", "coord_mode", "normalized_minus_one_to_one"),
        ("align_corners", "align_corners", False),
    )
    resample_kwargs = {
        keyword: mb.const(name=const_name, val=value)
        for keyword, const_name, value in const_specs
    }
    return mb.resample(x=x, coordinates=grid, **resample_kwargs)


# Convert the MIL program to an ML program. `convert_to` is spelled out so the
# ".mlpackage" save below does not depend on coremltools' default format for
# this deployment target. FLOAT32 compute precision is requested explicitly
# (per the coremltools docs, the mlprogram default is FLOAT16).
package_path = "GridSample.mlpackage"
m = ct.convert(
    prog,
    convert_to="mlprogram",
    minimum_deployment_target=target,
    compute_precision=ct.precision.FLOAT32,
)
m.save(package_path)

# Reload the saved package twice: once restricted to ComputeUnit.CPU_ONLY and
# once allowed to use ComputeUnit.ALL, so their predictions can be compared.
m_cpu = ct.models.MLModel(package_path, compute_units=ct.ComputeUnit.CPU_ONLY)
m_all = ct.models.MLModel(package_path, compute_units=ct.ComputeUnit.ALL)

# Inputs for GridSampleTest.test_grid_sample_20_4D_bilinear_reflection_no_align_corners.
# ORT produces a different output for this test; the ORT expected output was
# generated with PyTorch.
#
# Values are listed in C (row-major) order. Each row below fills one
# x[n, c, :, :] slice (3 x 2 = 6 values) of x_shape.
x = np.array(
    [
        -0.173652, -1.513725, -0.704586, -1.952375, -0.699404, -0.806298,
        1.640852, -0.138969, -0.695411, -1.352111, 0.568797, -0.564294,
        -0.056468, 0.641604, -0.438370, 0.450167, -1.091401, 1.669729,
        -0.908544, 0.244467, 0.172109, 1.156741, -0.617128, 1.155460,
    ],
    dtype=np.float32,
).reshape(x_shape)

# Each row below fills one grid[n, h, :, :] slice (2 x 2 = 4 values) of grid_shape.
grid = np.array(
    [
        0.252250, -0.151452, 0.824706, -0.588292,
        -0.591147, -0.155082, -0.732938, 0.457493,
        -0.439559, 0.492330, 0.696447, 0.700722,
        -0.220298, 0.654884, -0.635434, -1.195619,
        -0.114204, -0.870080, -0.929674, 0.305035,
        1.025429, -0.472240, -0.067881, -0.869393,
    ],
    dtype=np.float32,
).reshape(grid_shape)


# Run the same inputs through the CPU-only and the all-compute-units models and
# print the raw prediction dictionaries.
cpu_outputs = m_cpu.predict({"x": x, "grid": grid})
print(cpu_outputs)
all_outputs = m_all.predict({"x": x, "grid": grid})
print(all_outputs)

# Both models are loaded from the same package and differ only in compute
# units. Report the largest absolute difference per output rather than leaving
# the comparison to eyeballing the printed arrays.
for output_name, cpu_value in cpu_outputs.items():
    max_abs_diff = np.max(
        np.abs(np.asarray(cpu_value) - np.asarray(all_outputs[output_name]))
    )
    print(f"{output_name}: max |CPU_ONLY - ALL| = {max_abs_diff:g}")
